{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `KNN` from scratch! 🚀"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class KNN:\n",
    "    def __init__(self, k=3, task='classification'):\n",
    "        self.k = k\n",
    "        self.task = task\n",
    "\n",
    "    def _euclidean_distance(self, a, b):\n",
    "        # Calculate the Euclidean distance between two points\n",
    "        return np.sqrt(np.sum((a - b)**2, axis=1))\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        # Store training data and labels\n",
    "        self.X_train = X\n",
    "        self.y_train = y\n",
    "\n",
    "    def predict(self, X):\n",
    "        # Predict the class labels or target values for a set of data points\n",
    "        y_pred = [self._predict_single(x) for x in X]\n",
    "        return np.array(y_pred)\n",
    "\n",
    "    def _predict_single(self, x):\n",
    "        # Predict the class label or target value for a single data point\n",
    "        distances = self._euclidean_distance(x, self.X_train)\n",
    "        # Find K closest data points\n",
    "        k_indices = np.argsort(distances)[:self.k]\n",
    "        \n",
    "        # Get nearest neighbours\n",
    "        nn = [self.X_train[i].tolist() for i in k_indices] \n",
    "        print('Nearest_neighbours: ', nn)\n",
    "        \n",
    "        # Get their labels\n",
    "        k_nearest_labels = [self.y_train[i] for i in k_indices]  \n",
    "        print('Labels for Nearest_neighbours: ', k_nearest_labels)\n",
    "        \n",
    "        if self.task == 'classification':\n",
    "            return self._majority_vote(k_nearest_labels)\n",
    "        elif self.task == 'regression':\n",
    "            return self._average(k_nearest_labels)\n",
    "\n",
    "    def _majority_vote(self, labels):\n",
    "        # Determine the majority class label from a list of labels\n",
    "        return np.argmax(np.bincount(labels))\n",
    "\n",
    "    def _average(self, values):\n",
    "        # Calculate the average of a list of values\n",
    "        return np.mean(values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Let's test it for regression & Classification 🚀 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nearest_neighbours:  [[0, 0], [1, 1], [2, 2]]\n",
      "Labels for Nearest_neighbours:  [0, 0, 1]\n",
      "Predicted labels: [0]\n"
     ]
    }
   ],
   "source": [
    "# Test the KNN implementation\n",
    "if __name__ == \"__main__\":\n",
    "    X_train = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])\n",
    "\n",
    "    # class lavels 👇 \n",
    "    y_train = np.array([0, 0, 1, 1])\n",
    "    \n",
    "    X_test = np.array([[0.5, 0.5]])\n",
    "\n",
    "    knn = KNN(k=3, task='classification')\n",
    "    knn.fit(X_train, y_train)\n",
    "    y_pred = knn.predict(X_test)\n",
    "\n",
    "    print(\"Predicted labels:\", y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nearest_neighbours:  [[2, 2], [1, 1], [3, 3]]\n",
      "Labels for Nearest_neighbours:  [1, 0, 1]\n",
      "Predicted labels: [0.66666667]\n"
     ]
    }
   ],
   "source": [
    "# Test the KNN implementation\n",
    "if __name__ == \"__main__\":\n",
    "    X_train = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])\n",
    "\n",
    "    # class lavels 👇 \n",
    "    y_train = np.array([0, 0, 1, 1])\n",
    "    \n",
    "    X_test = np.array([[2, 2]])\n",
    "\n",
    "    knn = KNN(k=3, task='regression')\n",
    "    knn.fit(X_train, y_train)\n",
    "    y_pred = knn.predict(X_test)\n",
    "\n",
    "    print(\"Predicted labels:\", y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_twitter",
   "language": "python",
   "name": "env_twitter"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
